import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from parser1 import args

class MultiHeadSelfMaskAttention(nn.Module):
    """Multi-head self-attention with an attention mask, keeping separate
    projection weights for image features and text features.

    The constructor loads the pre-extracted feature matrices only to read
    their second dimension, which fixes the input size of each modality's
    Q/K/V projections.
    """

    def __init__(self, embed_size, num_heads):
        super(MultiHeadSelfMaskAttention, self).__init__()
        if embed_size % num_heads != 0:
            raise ValueError("embed_size must be divisible by num_heads")
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_size = embed_size // num_heads
        # Loaded to discover the raw feature dimensionality of each modality.
        self.image_feats = np.load(args.data_path + '{}/image_feat.npy'.format(args.dataset))
        self.text_feats = np.load(args.data_path + '{}/text_feat.npy'.format(args.dataset))
        img_size = self.image_feats.shape[1]
        txt_size = self.text_feats.shape[1]  # fixed typo: was `txt_szie`

        # Per-modality projection weights (no bias).
        self.i_WQ = nn.Linear(img_size, embed_size, bias=False)
        self.i_WK = nn.Linear(img_size, embed_size, bias=False)
        self.i_WV = nn.Linear(img_size, embed_size, bias=False)
        self.i_WO = nn.Linear(embed_size, embed_size, bias=False)

        self.t_WQ = nn.Linear(txt_size, embed_size, bias=False)
        self.t_WK = nn.Linear(txt_size, embed_size, bias=False)
        self.t_WV = nn.Linear(txt_size, embed_size, bias=False)
        self.t_WO = nn.Linear(embed_size, embed_size, bias=False)

    def compute(self, modality, x, mask):
        """Run masked multi-head self-attention for one modality.

        Args:
            modality: ``'text'`` or ``'image'`` — selects the weight set.
            x: ``(batch, seq_len, feat)`` input features.
            mask: tensor broadcastable to ``(batch, seq_len, seq_len)``;
                positions where ``mask == 0`` are excluded from attention.
                (Fixed: the original read ``mask`` from an undefined global
                instead of taking it as a parameter.)

        Returns:
            ``(batch, seq_len, embed_size)`` attended output.

        Raises:
            ValueError: if ``modality`` is not recognized.
        """
        batch_size, seq_len, _ = x.size()
        if modality == "text":
            WQ, WK, WV, WO = self.t_WQ, self.t_WK, self.t_WV, self.t_WO
        elif modality == "image":
            WQ, WK, WV, WO = self.i_WQ, self.i_WK, self.i_WV, self.i_WO
        else:
            raise ValueError("modality must be 'text' or 'image', got {!r}".format(modality))

        # Project and split into heads: (num_heads, batch, seq_len, head_size).
        Q = WQ(x).view(batch_size, seq_len, self.num_heads, self.head_size).permute(2, 0, 1, 3)
        K = WK(x).view(batch_size, seq_len, self.num_heads, self.head_size).permute(2, 0, 1, 3)
        V = WV(x).view(batch_size, seq_len, self.num_heads, self.head_size).permute(2, 0, 1, 3)

        # Scaled dot-product scores: (num_heads, batch, seq_len, seq_len).
        att_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_size ** 0.5)

        # mask (batch, seq, seq) broadcasts over the leading heads axis;
        # masked positions get -inf so softmax gives them zero weight.
        att_scores = att_scores.masked_fill(mask == 0, float('-inf'))
        att_weights = F.softmax(att_scores, dim=-1)

        # Weighted sum of V, then merge heads back: (batch, seq_len, embed_size).
        att_output = torch.matmul(att_weights, V).permute(1, 2, 0, 3).contiguous().view(batch_size, seq_len, -1)
        # Output projection (fixed: original referenced nonexistent self.WO).
        return WO(att_output)

    def forward(self, x, items):
        """Attend over ``x`` with mask ``items`` through both modality heads.

        NOTE(review): the original returned an undefined name ``output``;
        the two modality outputs are fused by summation here — confirm the
        intended fusion against the caller.
        """
        image = self.compute('image', x, items)
        text = self.compute('text', x, items)
        return image + text

# Example usage — guarded so that importing this module no longer builds a
# model (and loads .npy feature files via `args`) as an import side effect.
if __name__ == "__main__":
    embed_size = 256
    num_heads = 4
    seq_len = 10
    batch_size = 32

    # Build the model; its __init__ reads image/text features from disk.
    model = MultiHeadSelfMaskAttention(embed_size, num_heads)

    # Random input and mask; 0 marks positions to exclude from attention.
    # NOTE(review): x's last dim must match the .npy feature dimensionality
    # for the Q/K/V projections to apply — confirm embed_size equals it.
    x = torch.randn(batch_size, seq_len, embed_size)
    mask = torch.randint(2, size=(batch_size, seq_len, seq_len))

    output = model(x, mask)
    print(x.shape)
    print(output.shape)
